import scipy.stats as stats
import torch
import numpy as np

from .mnist import get_mnist_data
from .fashionmnist import get_fashion_mnist_data
from .cifar import get_cifar_10_data, get_cifar_100_data


def load_data(dset_name, batchsize, datasize, num_classes, train=True, shuffle=None, random_augment_train=True, imbalance=True, permutation=None):
    """Build the DataLoader(s) for one of the supported datasets.

    For ``"mnist"`` and ``"fmnist"`` the dataset is restricted to samples
    whose label is below ``num_classes`` (after the optional label
    permutation has been applied). Other datasets are wrapped unchanged.

    Args:
        dset_name: One of ``"fmnist"``, ``"mnist"``, ``"cifar10"``,
            ``"cifar100"``.
        batchsize: Batch size passed to every DataLoader.
        datasize: Currently unused; kept for interface compatibility.
        num_classes: Labels ``>= num_classes`` are dropped (MNIST /
            Fashion-MNIST only).
        train: Forwarded to the dataset getter; also selects the training
            or evaluation code path below.
        shuffle: Whether the loaders shuffle. ``None`` means "same as
            ``train``".
        random_augment_train: Forwarded to the dataset getter.
        imbalance: MNIST / Fashion-MNIST training only. When ``False`` at
            most 5000 samples per class are kept; when ``True`` the cap is
            10000 per class.
        permutation: Optional label remapping indexed by the original
            label (``permutation[old_label] -> new_label``). Only applied
            to MNIST / Fashion-MNIST; the dataset's ``targets`` are
            rewritten in place.

    Returns:
        * MNIST / Fashion-MNIST with ``train=True``: a list of
          ``num_classes + 1`` DataLoaders. The first covers every selected
          sample; entry ``1 + c`` covers only the samples of class ``c``.
        * Otherwise: a single DataLoader.

    Raises:
        KeyError: If ``dset_name`` is not a supported dataset.
    """
    dset_map = {
        "fmnist": get_fashion_mnist_data,
        "mnist": get_mnist_data,
        "cifar10": get_cifar_10_data,
        "cifar100": get_cifar_100_data,
    }
    data = dset_map[dset_name]("./data", train=train, download=True, random_augment_train=random_augment_train)
    if shuffle is None:
        shuffle = train

    # Only the MNIST-style datasets get label remapping / class filtering.
    if dset_name not in ("mnist", "fmnist"):
        return torch.utils.data.DataLoader(data, batch_size=batchsize, shuffle=shuffle)

    # NOTE(review): assumes data.targets is a 1-D integer tensor, which is
    # what the .shape / .item() usage in this code path has always required.
    targets = data.targets

    # Explicit None/length check: a bare truthiness test raises for
    # multi-element NumPy arrays and tensors, which are natural permutations.
    if permutation is not None and len(permutation) > 0:
        # Masks are computed against the untouched labels and written to a
        # copy, so every sample is remapped exactly once.
        remapped = targets.clone()
        for label in torch.unique(targets).tolist():
            remapped[targets == label] = int(permutation[label])
        targets.copy_(remapped)  # keep the in-place update of data.targets

    if not train:
        kept = torch.nonzero(targets < num_classes).flatten().tolist()
        data_subset = torch.utils.data.Subset(data, kept)
        return torch.utils.data.DataLoader(data_subset, batch_size=batchsize, shuffle=shuffle)

    # Keep the first `per_class_cap` samples (in dataset order) of each class.
    per_class_cap = 10000 if imbalance else 5000
    index_per_class = [
        torch.nonzero(targets == label).flatten()[:per_class_cap].tolist()
        for label in range(num_classes)
    ]
    # The combined subset lists the selected samples in dataset order.
    index_list = sorted(idx for indices in index_per_class for idx in indices)

    data_subset = torch.utils.data.Subset(data, index_list)
    loaders = [torch.utils.data.DataLoader(data_subset, batch_size=batchsize, shuffle=shuffle)]
    for indices in index_per_class:
        class_subset = torch.utils.data.Subset(data, indices)
        loaders.append(torch.utils.data.DataLoader(class_subset, batch_size=batchsize, shuffle=shuffle))
    return loaders


def corrupt_dataset_labels(loader, random_label_proportion, num_classes=10):
    """Replace a random fraction of a dataset's labels with random classes.

    Each label is independently selected for corruption with probability
    ``random_label_proportion``. Selected labels are replaced by a class
    drawn uniformly from ``[0, num_classes)``, so a "corrupted" label may
    coincide with the original one. The result is written back to
    ``loader.dataset.targets``.

    Args:
        loader: Object exposing ``loader.dataset.targets`` (a tensor or
            anything ``torch.as_tensor`` accepts, e.g. a list of ints).
        random_label_proportion: Probability in ``[0, 1]`` that any given
            label is replaced.
        num_classes: Number of classes to draw random labels from.

    Returns:
        None. ``loader.dataset.targets`` is replaced with a new tensor.
    """
    targets = torch.as_tensor(loader.dataset.targets)
    # np.int64 instead of the removed np.long alias (gone since NumPy 1.24).
    corruption_mask = stats.bernoulli.rvs(random_label_proportion,
                                          size=targets.shape[0]).astype(np.int64)
    use_random_label = torch.from_numpy(corruption_mask)
    random_labels = torch.randint(0, num_classes, targets.shape)
    # 0/1 mask arithmetic: pick the random label where the mask is 1.
    new_targets = use_random_label * random_labels + (1 - use_random_label) * targets
    loader.dataset.targets = new_targets
